{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adding an Environment \n",
    "\n",
    "Adding your custom environments to Coach will allow you to solve your own tasks using any of the predefined algorithms. There are two ways for adding your own environment to Coach:\n",
    "1. Implementing your environment as an OpenAI Gym environment\n",
    "2. Implementing a wrapper for your environment in Coach\n",
    "\n",
    "In this tutorial, we'll follow the 2nd option, and add the DeepMind Control Suite environment to Coach. We will then create a preset that trains a DDPG agent on one of the levels of the new environment."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "First, we will need to install the DeepMind Control Suite library. To do this, follow the installation instructions here: https://github.com/deepmind/dm_control#installation-and-requirements. \n",
    "\n",
    "\n",
    "Make sure your ```LD_LIBRARY_PATH``` contains the path to the GLEW and LGFW libraries (https://github.com/openai/mujoco-py/issues/110).\n",
    "\n",
    "\n",
    "In addition, Mujoco rendering might need to be disabled (https://github.com/deepmind/dm_control/issues/20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Environment Wrapper\n",
    "\n",
    "To integrate an environment with Coach, we need to implement an environment wrapper. Coach has several predefined environment wrappers which are placed under the environments folder, but we can place our new environment wherever we want and reference it later.\n",
    "\n",
    "Now let's define the control suite's environment wrapper class.\n",
    "\n",
    "In the ```__init__``` function we'll load and initialize the simulator using the level given by `self.env_id`.\n",
    "Additionally, we will define the state space and action space of the environment, through the `self.state_space` and `self.action_space` members.\n",
    "In this case, the state space is a dictionary consisting of 2 observations:\n",
    "* **'pixels'** - the image received from the mujoco camera, defined as an ImageObservationSpace.\n",
    "* **'measurements'** - the joint measurements of the model, defined as a VectorObservationSpace.\n",
    "The action space is a continuous space defined by the BoxActionSpace."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "from typing import Union\n",
    "\n",
    "from dm_control import suite\n",
    "from dm_control.suite.wrappers import pixels\n",
    "\n",
    "from rl_coach.base_parameters import VisualizationParameters\n",
    "from rl_coach.spaces import BoxActionSpace, ImageObservationSpace, VectorObservationSpace, StateSpace\n",
    "from rl_coach.environments.environment import Environment, LevelSelection\n",
    "\n",
    "\n",
    "# Environment\n",
    "class ControlSuiteEnvironment(Environment):\n",
    "    def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,\n",
    "                 seed: Union[None, int]=None, human_control: bool=False,\n",
    "                 custom_reward_threshold: Union[int, float]=None, **kwargs):\n",
    "        super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters)\n",
    "        \n",
    "        # load and initialize environment\n",
    "        domain_name, task_name = self.env_id.split(\":\")\n",
    "        self.env = suite.load(domain_name=domain_name, task_name=task_name)\n",
    "        self.env = pixels.Wrapper(self.env, pixels_only=False)\n",
    "\n",
    "        # seed\n",
    "        if self.seed is not None:\n",
    "            np.random.seed(self.seed)\n",
    "            random.seed(self.seed)\n",
    "\n",
    "        self.state_space = StateSpace({})\n",
    "\n",
    "        # image observations\n",
    "        self.state_space['pixels'] = ImageObservationSpace(shape=self.env.observation_spec()['pixels'].shape,\n",
    "                                                           high=255)\n",
    "\n",
    "        # measurements observations\n",
    "        measurements_space_size = 0\n",
    "        measurements_names = []\n",
    "        for observation_space_name, observation_space in self.env.observation_spec().items():\n",
    "            if len(observation_space.shape) == 0:\n",
    "                measurements_space_size += 1\n",
    "                measurements_names.append(observation_space_name)\n",
    "            elif len(observation_space.shape) == 1:\n",
    "                measurements_space_size += observation_space.shape[0]\n",
    "                measurements_names.extend([\"{}_{}\".format(observation_space_name, i) for i in\n",
    "                                            range(observation_space.shape[0])])\n",
    "        self.state_space['measurements'] = VectorObservationSpace(shape=measurements_space_size,\n",
    "                                                                  measurements_names=measurements_names)\n",
    "\n",
    "        # actions\n",
    "        self.action_space = BoxActionSpace(\n",
    "            shape=self.env.action_spec().shape[0],\n",
    "            low=self.env.action_spec().minimum,\n",
    "            high=self.env.action_spec().maximum\n",
    "        )\n",
    "\n",
    "        # initialize the state by getting a new state from the environment\n",
    "        self.reset_internal_state(True)"
   ]
  },
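  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the measurements bookkeeping in `__init__` easier to follow, here is a minimal standalone sketch of the same size and name computation. It does not use dm_control: `observation_shapes` and `flatten_measurements_spec` are hypothetical names introduced only for this illustration, and the shapes are made up.\n",
    "\n",
    "```python\n",
    "# hypothetical observation shapes, standing in for env.observation_spec()\n",
    "observation_shapes = {\n",
    "    'position': (3,),       # 1D vector, contributes 3 entries\n",
    "    'velocity': (2,),       # 1D vector, contributes 2 entries\n",
    "    'upright': (),          # scalar, contributes 1 entry\n",
    "    'pixels': (84, 84, 3),  # image, skipped because it has its own observation space\n",
    "}\n",
    "\n",
    "\n",
    "def flatten_measurements_spec(shapes):\n",
    "    # accumulate the total vector size and one name per vector entry\n",
    "    size, names = 0, []\n",
    "    for name, shape in shapes.items():\n",
    "        if len(shape) == 0:\n",
    "            size += 1\n",
    "            names.append(name)\n",
    "        elif len(shape) == 1:\n",
    "            size += shape[0]\n",
    "            names.extend('{}_{}'.format(name, i) for i in range(shape[0]))\n",
    "    return size, names\n",
    "\n",
    "\n",
    "measurements_size, measurements_names = flatten_measurements_spec(observation_shapes)\n",
    "print(measurements_size, measurements_names)\n",
    "```\n",
    "\n",
    "Scalars contribute one entry, 1D vectors contribute one entry per element, and anything with more dimensions (such as the image) is left out of the measurements vector."
   ]
  },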
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following functions cover the API expected from a new environment wrapper:\n",
    "\n",
    "1. ```_update_state``` - update the internal state of the wrapper (to be queried by the agent),\n",
    "   which consists of:\n",
    "   * `self.state` - a dictionary containing all the observations from the environment and which follows the state space definition.\n",
    "   * `self.reward` - a float value containing the reward for the last step of the environment\n",
    "   * `self.done` - a boolean flag which signals if the environment episode has ended\n",
    "   * `self.goal` - a numpy array representing the goal the environment has set for the last step\n",
    "   * `self.info` - a dictionary that contains any additional information for the last step\n",
    "   \n",
    "2. ```_take_action``` - gets the action from the agent, and make a single step on the environment\n",
    "3. ```_restart_environment_episode``` - restart the environment on a new episode \n",
    "4. ```get_rendered_image``` - get a rendered image of the environment in its current state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ControlSuiteEnvironment(Environment):\n",
    "    def _update_state(self):\n",
    "        self.state = {}\n",
    "\n",
    "        self.pixels = self.last_result.observation['pixels']\n",
    "        self.state['pixels'] = self.pixels\n",
    "\n",
    "        self.measurements = np.array([])\n",
    "        for sub_observation in self.last_result.observation.values():\n",
    "            if isinstance(sub_observation, np.ndarray) and len(sub_observation.shape) == 1:\n",
    "                self.measurements = np.concatenate((self.measurements, sub_observation))\n",
    "            else:\n",
    "                self.measurements = np.concatenate((self.measurements, np.array([sub_observation])))\n",
    "        self.state['measurements'] = self.measurements\n",
    "\n",
    "        self.reward = self.last_result.reward if self.last_result.reward is not None else 0\n",
    "\n",
    "        self.done = self.last_result.last()\n",
    "\n",
    "    def _take_action(self, action):\n",
    "        if type(self.action_space) == BoxActionSpace:\n",
    "            action = self.action_space.clip_action_to_space(action)\n",
    "\n",
    "        self.last_result = self.env.step(action)\n",
    "\n",
    "    def _restart_environment_episode(self, force_environment_reset=False):\n",
    "        self.last_result = self.env.reset()\n",
    "\n",
    "    def get_rendered_image(self):\n",
    "        return self.env.physics.render(camera_id=0)"
   ]
  },
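  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The division of labor between these functions can be sketched without Coach or dm_control. The toy class below is a hypothetical stand-in, not a Coach class: `ToyEnvironment`, its dynamics and its `step` driver are invented for illustration, and the exact order in which Coach's `Environment` base class calls the hooks is an assumption here.\n",
    "\n",
    "```python\n",
    "class ToyEnvironment:\n",
    "    # hypothetical stand-in that mimics three of the wrapper hooks\n",
    "    def __init__(self, episode_length=3):\n",
    "        self.episode_length = episode_length\n",
    "        self._restart_environment_episode()\n",
    "        self._update_state()\n",
    "\n",
    "    def _restart_environment_episode(self):\n",
    "        # put the simulated system back in its initial configuration\n",
    "        self.position, self.steps = 0.0, 0\n",
    "\n",
    "    def _take_action(self, action):\n",
    "        # advance the simulated system by a single step\n",
    "        self.position += action\n",
    "        self.steps += 1\n",
    "\n",
    "    def _update_state(self):\n",
    "        # refresh the fields that the agent reads after every reset or step\n",
    "        self.state = {'measurements': [self.position]}\n",
    "        self.reward = -abs(self.position)\n",
    "        self.done = self.steps >= self.episode_length\n",
    "\n",
    "    def step(self, action):\n",
    "        # assumed driver: act first, then expose the resulting state\n",
    "        self._take_action(action)\n",
    "        self._update_state()\n",
    "        return self.state, self.reward, self.done\n",
    "\n",
    "\n",
    "env = ToyEnvironment()\n",
    "trajectory = []\n",
    "while not env.done:\n",
    "    trajectory.append(env.step(0.5))\n",
    "print(trajectory)\n",
    "```\n",
    "\n",
    "The point of the split is that `_take_action` only changes the simulator, while `_update_state` is the single place that translates the simulator output into `state`, `reward` and `done`."
   ]
  },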
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we will need to define a parameters class corresponding to our environment class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.environments.environment import EnvironmentParameters\n",
    "from rl_coach.filters.filter import NoInputFilter, NoOutputFilter\n",
    "\n",
    "# Parameters\n",
    "class ControlSuiteEnvironmentParameters(EnvironmentParameters):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.default_input_filter = NoInputFilter()\n",
    "        self.default_output_filter = NoOutputFilter()\n",
    "\n",
    "    @property\n",
    "    def path(self):\n",
    "        return 'environments.control_suite_environment:ControlSuiteEnvironment'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Preset\n",
    "\n",
    "Now that we have our new environment, we will want to use one of the predefined algorithms to try and solve it.\n",
    "In this case, since the environment defines a continuous action space, we will want to use a supporting algorithm, so we will select DDPG. To run DDPG on the environment, we will need to define a preset for it.\n",
    "The new preset will typically be defined in a new file - ```presets\\ControlSuite_DDPG.py```. \n",
    "\n",
    "First - let's define the agent parameters. We can use the default parameters for the DDPG agent, except that we need to update the networks input embedders to point to the correct environment observation. When we defined the environment, we set it to have 2 observations - 'pixels' and 'measurements'. In this case, we will want to learn only from the measurements, so we will need to modify the default input embedders to point to 'measurements' instead of the default 'observation' defined in `DDPGAgentParameters`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.agents.ddpg_agent import DDPGAgentParameters\n",
    "\n",
    "\n",
    "agent_params = DDPGAgentParameters()\n",
    "# rename the input embedder key from 'observation' to 'measurements'\n",
    "agent_params.network_wrappers['actor'].input_embedders_parameters['measurements'] = \\\n",
    "    agent_params.network_wrappers['actor'].input_embedders_parameters.pop('observation')\n",
    "agent_params.network_wrappers['critic'].input_embedders_parameters['measurements'] = \\\n",
    "    agent_params.network_wrappers['critic'].input_embedders_parameters.pop('observation')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's define the environment parameters. The DeepMind Control Suite environment has many levels to select from. The level can be selected either as a specific level name, for example 'cartpole:swingup', or by a list of level names from which a single level should be selected. The later can be done using the `SingleLevelSelection` class, and then the level can be selected from the command line using the `-lvl` flag."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.environments.control_suite_environment import ControlSuiteEnvironmentParameters, control_suite_envs\n",
    "from rl_coach.environments.environment import SingleLevelSelection\n",
    "\n",
    "env_params = ControlSuiteEnvironmentParameters(level='cartpole:balance')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will also need to define a schedule for the training. The schedule defines the number of steps we want run our experiment for and when to evaluate the trained model. In this case, we will use a simple predefined schedule, and just add some heatup steps to fill up the agent memory buffers with initial data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.graph_managers.graph_manager import SimpleSchedule\n",
    "from rl_coach.core_types import EnvironmentSteps\n",
    "\n",
    "schedule_params = SimpleSchedule()\n",
    "schedule_params.heatup_steps = EnvironmentSteps(1000)"
   ]
  },
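  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Heatup is described above as filling the agent memory buffers with initial data. A rough sketch of that idea follows, assuming the data is collected with random actions before any training takes place. The buffer, the toy dynamics and all names are hypothetical and are not Coach's implementation.\n",
    "\n",
    "```python\n",
    "import random\n",
    "from collections import deque\n",
    "\n",
    "random.seed(0)  # fixed seed so that the sketch is reproducible\n",
    "\n",
    "heatup_steps = 1000                    # same number of steps as in the schedule\n",
    "replay_buffer = deque(maxlen=100000)   # hypothetical buffer, not Coach's memory class\n",
    "\n",
    "state = 0.0\n",
    "for _ in range(heatup_steps):\n",
    "    action = random.uniform(-1.0, 1.0)  # random action, no policy is involved yet\n",
    "    next_state = state + action         # toy dynamics invented for this sketch\n",
    "    reward = -abs(next_state)\n",
    "    replay_buffer.append((state, action, reward, next_state))\n",
    "    state = next_state\n",
    "\n",
    "print(len(replay_buffer))\n",
    "```\n",
    "\n",
    "Once the buffer holds this initial data, an off-policy algorithm such as DDPG has transitions to sample from as soon as training starts."
   ]
  },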
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will also want to see the simulator in action (otherwise we will miss all the fun), so let's set the `render` flag to True in the visualization parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.base_parameters import VisualizationParameters\n",
    "\n",
    "vis_params = VisualizationParameters(render=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we'll create and run the graph manager"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager\n",
    "\n",
    "graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,\n",
    "                                    schedule_params=schedule_params, vis_params=vis_params)\n",
    "\n",
    "# let the adventure begin\n",
    "graph_manager.improve()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
